package backtest

import (
	"fmt"
	"math"
	"strings"
)

// epsilon is the tolerance used for floating-point comparisons in this file:
// position quantities at or below it are treated as zero, and it absorbs
// rounding error in the available-cash and close-quantity checks.
const epsilon = 1e-8

// position is the account's record of one open leveraged position. The
// account stores at most one position per symbol/side pair.
type position struct {
	Symbol           string  // instrument symbol; upper-cased when created via ensurePosition
	Side             string  // "long"; any other value is treated as short by the PnL helpers
	Quantity         float64 // open size
	EntryPrice       float64 // quantity-weighted average execution price
	Leverage         int     // leverage used to derive the liquidation price
	Margin           float64 // cash currently locked as margin for this position
	Notional         float64 // value of the open quantity at entry prices
	LiquidationPrice float64 // price produced by computeLiquidation
	OpenTime         int64   // timestamp passed to Open when the position was first created
	StopLoss         float64 // stop-loss price; 0 means not set
	TakeProfit       float64 // take-profit price; 0 means not set
}

// BacktestAccount is an in-memory margin account: it tracks free cash, the
// open positions and the running realized PnL.
type BacktestAccount struct {
	initialBalance float64              // starting balance passed to NewBacktestAccount
	cash           float64              // free cash, i.e. not locked as position margin
	feeRate        float64              // fee as a fraction of notional (fee bps / 10000)
	slippageRate   float64              // slippage as a fraction of price (slippage bps / 10000)
	positions      map[string]*position // open positions keyed by positionKey(symbol, side)
	realizedPnL    float64              // cumulative realized PnL net of closing fees
}

// NewBacktestAccount returns an account holding initialBalance in cash and no
// positions. feeBps and slippageBps are given in basis points and stored as
// plain fractions (bps / 10000).
func NewBacktestAccount(initialBalance, feeBps, slippageBps float64) *BacktestAccount {
	const bpsPerUnit = 10000.0

	acc := new(BacktestAccount)
	acc.initialBalance = initialBalance
	acc.cash = initialBalance
	acc.feeRate = feeBps / bpsPerUnit
	acc.slippageRate = slippageBps / bpsPerUnit
	acc.positions = map[string]*position{}
	return acc
}

// positionKey builds the map key under which a position is stored:
// the upper-cased symbol, a colon, and the side exactly as given.
func positionKey(symbol, side string) string {
	const sep = ":"
	normalized := strings.ToUpper(symbol)
	return normalized + sep + side
}

// ensurePosition returns the position stored for symbol/side. When none is
// stored yet, an empty position (upper-cased symbol, given side, all numeric
// fields zero) is registered in the account and returned.
func (acc *BacktestAccount) ensurePosition(symbol, side string) *position {
	key := positionKey(symbol, side)
	held, found := acc.positions[key]
	if !found {
		held = &position{Symbol: strings.ToUpper(symbol), Side: side}
		acc.positions[key] = held
	}
	return held
}

// removePosition drops pos from the account's position map, using the key
// derived from the position's own symbol and side.
func (acc *BacktestAccount) removePosition(pos *position) {
	delete(acc.positions, positionKey(pos.Symbol, pos.Side))
}

// Open opens a new position for symbol/side, or adds to the one already held.
//
// The order fills at price adjusted by the account's slippage rate. Margin
// (notional / leverage) plus the opening fee (notional * fee rate) is deducted
// from cash. When a position with the same symbol and side already exists the
// fill is merged into it: quantity, notional and margin are summed, the entry
// price becomes the quantity-weighted average, the leverage is re-derived as
// total notional / total margin (rounded) if the two leverages differ, the
// liquidation price is recomputed, and stopLoss / takeProfit overwrite the
// stored levels only when they are > 0.
//
// It returns the affected position, the fee charged and the execution price.
// An error is returned, with the account left unchanged, when quantity or
// price is not a positive number, leverage is outside 1..100, a new position
// would exceed the 20-position cap, the notional exceeds 50x account equity,
// or cash cannot cover margin plus fee.
func (acc *BacktestAccount) Open(symbol, side string, quantity float64, leverage int, price, stopLoss, takeProfit float64, ts int64) (*position, float64, float64, error) {
	if quantity <= 0 || math.IsNaN(quantity) {
		return nil, 0, 0, fmt.Errorf("quantity must be positive")
	}
	// A non-positive price would yield a zero or negative notional, and a
	// negative margin would credit cash instead of debiting it.
	if price <= 0 || math.IsNaN(price) {
		return nil, 0, 0, fmt.Errorf("price must be positive")
	}
	if leverage <= 0 {
		return nil, 0, 0, fmt.Errorf("leverage must be positive")
	}

	// Risk guard: cap the leverage.
	const MaxLeverage = 100
	if leverage > MaxLeverage {
		return nil, 0, 0, fmt.Errorf("leverage %d exceeds maximum allowed leverage of %d", leverage, MaxLeverage)
	}

	// Risk guard: cap the number of distinct positions. Adding to a position
	// that is already held does not create a new entry, so it is not limited.
	const MaxPositions = 20
	if _, exists := acc.positions[positionKey(symbol, side)]; !exists && len(acc.positions) >= MaxPositions {
		return nil, 0, 0, fmt.Errorf("maximum position count (%d) reached, cannot open new position", MaxPositions)
	}

	execPrice := applySlippage(price, acc.slippageRate, side, true)
	notional := execPrice * quantity
	margin := notional / float64(leverage)
	fee := notional * acc.feeRate

	// Risk guard: a single order's notional may not exceed 50x account equity.
	// Equity is cash plus all locked margin, plus unrealized PnL for positions
	// in the symbol being traded. price is the only mark available here, so
	// positions in other symbols are carried at entry (zero unrealized PnL)
	// rather than being valued against a price we do not have.
	totalEquity := acc.cash
	for _, held := range acc.positions {
		totalEquity += held.Margin
		if strings.EqualFold(held.Symbol, symbol) {
			totalEquity += unrealizedPnL(held, price)
		}
	}
	const MaxNotionalMultiplier = 50.0
	if notional > totalEquity*MaxNotionalMultiplier {
		return nil, 0, 0, fmt.Errorf("notional value %.2f exceeds maximum allowed (%.2f x %.0fx = %.2f)",
			notional, totalEquity, MaxNotionalMultiplier, totalEquity*MaxNotionalMultiplier)
	}

	if margin+fee > acc.cash+epsilon {
		return nil, 0, 0, fmt.Errorf("insufficient cash: need %.2f", margin+fee)
	}

	acc.cash -= margin + fee

	pos := acc.ensurePosition(symbol, side)

	if pos.Quantity < epsilon {
		// Fresh (or effectively empty) position: take every field from this fill.
		pos.Quantity = quantity
		pos.EntryPrice = execPrice
		pos.Leverage = leverage
		pos.Margin = margin
		pos.Notional = notional
		pos.OpenTime = ts
		pos.LiquidationPrice = computeLiquidation(execPrice, leverage, side)
		pos.StopLoss = stopLoss
		pos.TakeProfit = takeProfit
	} else {
		if leverage != pos.Leverage {
			// Approximate the blended leverage as total notional over total margin.
			weightedMargin := pos.Margin + margin
			pos.Leverage = int(math.Round((pos.Notional + notional) / weightedMargin))
		}
		pos.Notional += notional
		pos.Margin += margin
		pos.EntryPrice = ((pos.EntryPrice * pos.Quantity) + execPrice*quantity) / (pos.Quantity + quantity)
		pos.Quantity += quantity
		pos.LiquidationPrice = computeLiquidation(pos.EntryPrice, pos.Leverage, side)
		// When adding to a position, only replace the stop-loss / take-profit
		// levels if new values were supplied.
		if stopLoss > 0 {
			pos.StopLoss = stopLoss
		}
		if takeProfit > 0 {
			pos.TakeProfit = takeProfit
		}
	}

	return pos, fee, execPrice, nil
}

// Close reduces or fully closes the symbol/side position.
//
// A quantity of zero (within epsilon, from below) means "close everything".
// Any other non-positive quantity, a NaN, or a quantity exceeding the position
// size by more than epsilon is rejected. A quantity that differs from the
// position size by at most epsilon is treated as a full close, so the margin
// released always matches the margin held and no dust remainder is dropped
// together with its margin.
//
// The order fills at price adjusted by the account's slippage rate. The
// proportional share of margin plus realized PnL minus the closing fee is
// credited to cash, and realized PnL net of the fee is added to the account's
// running total. The position is removed once its quantity reaches zero.
//
// It returns the realized PnL (before the fee), the fee and the execution
// price, or an error if there is no active position or quantity is invalid.
func (acc *BacktestAccount) Close(symbol, side string, quantity float64, price float64) (float64, float64, float64, error) {
	key := positionKey(symbol, side)
	pos, ok := acc.positions[key]
	if !ok || pos.Quantity <= epsilon {
		return 0, 0, 0, fmt.Errorf("no active %s position for %s", side, symbol)
	}

	if math.IsNaN(quantity) || quantity <= 0 || quantity > pos.Quantity+epsilon {
		if math.IsNaN(quantity) || math.Abs(quantity) > epsilon {
			return 0, 0, 0, fmt.Errorf("invalid close quantity")
		}
		// Zero quantity is shorthand for closing the whole position.
		quantity = pos.Quantity
	}
	// Snap near-full closes to an exact full close: this clamps a quantity that
	// overshoots within tolerance and avoids leaving a sub-epsilon remainder.
	if math.Abs(pos.Quantity-quantity) <= epsilon {
		quantity = pos.Quantity
	}

	execPrice := applySlippage(price, acc.slippageRate, side, false)
	notional := execPrice * quantity
	fee := notional * acc.feeRate

	realized := realizedPnL(pos, quantity, execPrice)

	// Release margin in proportion to the share of the position being closed.
	marginPortion := pos.Margin * (quantity / pos.Quantity)
	acc.cash += marginPortion + realized - fee
	acc.realizedPnL += realized - fee

	pos.Quantity -= quantity
	// Notional is tracked at entry prices, so reduce it by entry price, not by
	// the closing price.
	pos.Notional -= pos.EntryPrice * quantity
	pos.Margin -= marginPortion

	if pos.Quantity <= epsilon {
		acc.removePosition(pos)
	}

	return realized, fee, execPrice, nil
}

// UpdateStopLoss sets the stop-loss price of the active symbol/side position
// to newStopLoss. It returns an error when no such position exists or its
// quantity is at or below epsilon.
func (acc *BacktestAccount) UpdateStopLoss(symbol, side string, newStopLoss float64) error {
	target, found := acc.positions[positionKey(symbol, side)]
	switch {
	case !found, target.Quantity <= epsilon:
		return fmt.Errorf("no active %s position for %s", side, symbol)
	}
	target.StopLoss = newStopLoss
	return nil
}

// UpdateTakeProfit sets the take-profit price of the active symbol/side
// position to newTakeProfit. It returns an error when no such position exists
// or its quantity is at or below epsilon.
func (acc *BacktestAccount) UpdateTakeProfit(symbol, side string, newTakeProfit float64) error {
	target, found := acc.positions[positionKey(symbol, side)]
	switch {
	case !found, target.Quantity <= epsilon:
		return fmt.Errorf("no active %s position for %s", side, symbol)
	}
	target.TakeProfit = newTakeProfit
	return nil
}

// TotalEquity marks the open positions against priceMap (keyed by position
// symbol) and returns, in order:
//   - total equity: cash + locked margin + unrealized PnL,
//   - the total unrealized PnL,
//   - the unrealized PnL of each position, keyed by "SYMBOL:side".
//
// A position whose symbol is missing from priceMap, or whose price is not
// positive, contributes zero unrealized PnL (it is carried at its entry
// price). Marking it against an absent price of 0 would fabricate a full loss
// for longs and a matching gain for shorts.
func (acc *BacktestAccount) TotalEquity(priceMap map[string]float64) (float64, float64, map[string]float64) {
	unrealized := 0.0
	margin := 0.0
	perSymbol := make(map[string]float64, len(acc.positions))
	for _, pos := range acc.positions {
		margin += pos.Margin

		pnl := 0.0
		if price, ok := priceMap[pos.Symbol]; ok && price > 0 {
			pnl = unrealizedPnL(pos, price)
		}
		unrealized += pnl
		perSymbol[pos.Symbol+":"+pos.Side] = pnl
	}
	return acc.cash + margin + unrealized, unrealized, perSymbol
}

// applySlippage returns price moved against the trader by rate. Orders that
// buy (opening a long, or closing any non-long side) fill higher by a factor
// of 1+rate; orders that sell (closing a long, or opening any non-long side)
// fill lower by a factor of 1-rate. A rate of zero or less leaves price
// unchanged.
func applySlippage(price float64, rate float64, side string, isOpen bool) float64 {
	if rate <= 0 {
		return price
	}
	// Opening a long and closing a short are both buys.
	buying := (side == "long") == isOpen
	if buying {
		return price * (1.0 + rate)
	}
	return price * (1.0 - rate)
}

// computeLiquidation returns the price at which an adverse move of
// 1/leverage from entry is reached: entry*(1-1/leverage) for side "long",
// entry*(1+1/leverage) for any other side. It returns 0 when leverage is not
// positive.
func computeLiquidation(entry float64, leverage int, side string) float64 {
	if leverage <= 0 {
		return 0
	}
	marginFraction := 1.0 / float64(leverage)
	if side != "long" {
		return entry * (1.0 + marginFraction)
	}
	return entry * (1.0 - marginFraction)
}

// realizedPnL returns the profit or loss from closing qty of pos at price:
// (price - entry) * qty for side "long", (entry - price) * qty for any other
// side. Fees are not included.
func realizedPnL(pos *position, qty, price float64) float64 {
	move := pos.EntryPrice - price
	if pos.Side == "long" {
		move = price - pos.EntryPrice
	}
	return move * qty
}

// unrealizedPnL returns the mark-to-market profit or loss of the whole
// position at price: (price - entry) * quantity for side "long",
// (entry - price) * quantity for any other side.
func unrealizedPnL(pos *position, price float64) float64 {
	move := pos.EntryPrice - price
	if pos.Side == "long" {
		move = price - pos.EntryPrice
	}
	return move * pos.Quantity
}

// Positions returns the held positions in a newly allocated slice. The
// elements are the account's own *position values, not copies, and their order
// follows map iteration, so it is not stable between calls.
func (acc *BacktestAccount) Positions() []*position {
	out := make([]*position, len(acc.positions))
	next := 0
	for _, held := range acc.positions {
		out[next] = held
		next++
	}
	return out
}

// positionLeverage returns the leverage of the active symbol/side position,
// or 0 when no such position exists or its quantity does not exceed epsilon.
func (acc *BacktestAccount) positionLeverage(symbol, side string) int {
	held, found := acc.positions[positionKey(symbol, side)]
	if !found {
		return 0
	}
	if held.Quantity > epsilon {
		return held.Leverage
	}
	return 0
}

// StopLossTakeProfitTrigger describes one stop-loss or take-profit condition
// that CheckStopLossTakeProfit found to be met.
type StopLossTakeProfitTrigger struct {
	Position     *position // the position whose level was reached
	TriggerType  string    // "stop_loss" or "take_profit"
	TriggerPrice float64   // the stop-loss or take-profit level that was reached
	CurrentPrice float64   // the price from the price map that reached it
	Reason       string    // formatted description of the comparison that fired
}

// CheckStopLossTakeProfit evaluates the stop-loss and take-profit levels of
// every active position against priceMap (keyed by position symbol) and
// returns one trigger per position whose level has been reached.
//
// Positions with quantity at or below epsilon, with no entry in priceMap, or
// with a non-positive price are skipped. Levels of 0 or less count as unset.
// For a long, the stop fires when price <= StopLoss and the target when
// price >= TakeProfit; for a short the comparisons are mirrored. When both
// would fire, only the stop-loss is reported. Sides other than "long" and
// "short" never trigger. The result is nil when nothing fired; its order
// follows map iteration.
func (acc *BacktestAccount) CheckStopLossTakeProfit(priceMap map[string]float64) []StopLossTakeProfitTrigger {
	var fired []StopLossTakeProfitTrigger

	for _, pos := range acc.positions {
		if pos.Quantity <= epsilon {
			continue
		}

		mark, known := priceMap[pos.Symbol]
		if !known || mark <= 0 {
			continue
		}

		stopSet := pos.StopLoss > 0
		targetSet := pos.TakeProfit > 0

		switch pos.Side {
		case "long":
			switch {
			case stopSet && mark <= pos.StopLoss:
				// Stop-loss: price fell to or below the stop. Listed first so
				// it takes priority over the take-profit.
				fired = append(fired, StopLossTakeProfitTrigger{
					Position:     pos,
					TriggerType:  "stop_loss",
					TriggerPrice: pos.StopLoss,
					CurrentPrice: mark,
					Reason:       fmt.Sprintf("多头止损触发: %.4f <= %.4f", mark, pos.StopLoss),
				})
			case targetSet && mark >= pos.TakeProfit:
				// Take-profit: price rose to or above the target.
				fired = append(fired, StopLossTakeProfitTrigger{
					Position:     pos,
					TriggerType:  "take_profit",
					TriggerPrice: pos.TakeProfit,
					CurrentPrice: mark,
					Reason:       fmt.Sprintf("多头止盈触发: %.4f >= %.4f", mark, pos.TakeProfit),
				})
			}
		case "short":
			switch {
			case stopSet && mark >= pos.StopLoss:
				// Stop-loss: price rose to or above the stop.
				fired = append(fired, StopLossTakeProfitTrigger{
					Position:     pos,
					TriggerType:  "stop_loss",
					TriggerPrice: pos.StopLoss,
					CurrentPrice: mark,
					Reason:       fmt.Sprintf("空头止损触发: %.4f >= %.4f", mark, pos.StopLoss),
				})
			case targetSet && mark <= pos.TakeProfit:
				// Take-profit: price fell to or below the target.
				fired = append(fired, StopLossTakeProfitTrigger{
					Position:     pos,
					TriggerType:  "take_profit",
					TriggerPrice: pos.TakeProfit,
					CurrentPrice: mark,
					Reason:       fmt.Sprintf("空头止盈触发: %.4f <= %.4f", mark, pos.TakeProfit),
				})
			}
		}
	}

	return fired
}

// Cash returns the account's free cash, i.e. the balance not currently locked
// as position margin.
func (acc *BacktestAccount) Cash() float64 {
	return acc.cash
}

// InitialBalance returns the starting balance the account was created with.
func (acc *BacktestAccount) InitialBalance() float64 {
	return acc.initialBalance
}

// RealizedPnL returns the cumulative realized profit and loss recorded by
// Close, net of closing fees.
func (acc *BacktestAccount) RealizedPnL() float64 {
	return acc.realizedPnL
}

// RestoreFromSnapshots resets the account to a checkpointed state: cash and
// realized PnL are overwritten and the position map is rebuilt from snaps,
// discarding every position held before the call. The initial balance is left
// as it was.
//
// Symbols are upper-cased so restored positions follow the same convention as
// positions created through ensurePosition, which price-map lookups by
// position symbol depend on. Notional is recomputed as Quantity * AvgPrice.
//
// NOTE(review): StopLoss and TakeProfit are not restored here, so restored
// positions come back with both unset; confirm whether PositionSnapshot
// carries them.
func (acc *BacktestAccount) RestoreFromSnapshots(cash float64, realized float64, snaps []PositionSnapshot) {
	acc.cash = cash
	acc.realizedPnL = realized
	acc.positions = make(map[string]*position, len(snaps))
	for _, snap := range snaps {
		pos := &position{
			Symbol:           strings.ToUpper(snap.Symbol),
			Side:             snap.Side,
			Quantity:         snap.Quantity,
			EntryPrice:       snap.AvgPrice,
			Leverage:         snap.Leverage,
			Margin:           snap.MarginUsed,
			Notional:         snap.Quantity * snap.AvgPrice,
			LiquidationPrice: snap.LiquidationPrice,
			OpenTime:         snap.OpenTime,
		}
		acc.positions[positionKey(pos.Symbol, pos.Side)] = pos
	}
}
